# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Exploitability measurement utils for symmetric games."""

from absl import logging  # pylint:disable=unused-import

import numpy as np
from scipy import special

from open_spiel.python.algorithms.adidas_utils.helpers import misc


def unreg_exploitability(dist, payoff_tensor):
  """Measure how much a best response gains over dist in a symmetric game.

  No regularization is applied: the result is the largest entry of the payoff
  gradient minus the payoff obtained by playing dist against that gradient.

  Args:
    dist: 1-d np.array, current estimate of nash distribution
    payoff_tensor: (>=1 x A x ... x A) np.array, payoffs for each joint action;
      only the first player's slice, payoff_tensor[0], is used
  Returns:
    exploitability (float): payoff of best response - payoff of dist
  """
  # One copy of dist per player is handed to the reduction helper.
  # NOTE(review): presumably pt_reduce contracts every axis except axis 0 with
  # dist, giving one expected payoff per action -- confirm against misc.
  opponents = [dist] * payoff_tensor.shape[0]
  payoff_grad = misc.pt_reduce(payoff_tensor[0], opponents, [0])

  best_response_value = np.max(payoff_grad)
  dist_value = np.dot(payoff_grad, dist)

  return best_response_value - dist_value


def ate_exploitability(dist, payoff_tensor, p=1):
  """Compute Tsallis regularized exploitability of dist for symmetric game.

  The regularized payoff of a strategy x is
    nabla.dot(x) + s / (p + 1) * (1 - sum(x**(p + 1)))
  where nabla is the payoff gradient at dist and s is the (1/p)-norm of nabla
  (the infinity norm when p is not positive).

  Args:
    dist: 1-d np.array, current estimate of nash distribution
    payoff_tensor: (>=1 x A x ... x A) np.array, payoffs for each joint action
      assumed to be non-negative
    p: float in [0, 1], Tsallis entropy-regularization --> 0 as p --> 0
  Returns:
    exploitability (float): payoff of best response - payoff of dist
  Raises:
    ValueError: if payoff_tensor contains a negative entry
  """
  if payoff_tensor.min() < 0.:
    raise ValueError('payoff tensor must be non-negative')
  num_players = payoff_tensor.shape[0]
  # NOTE(review): presumably pt_reduce contracts every axis except axis 0 with
  # dist, giving one expected payoff per action -- confirm against misc.
  nabla = misc.pt_reduce(payoff_tensor[0], [dist] * num_players, [0])
  if p > 0:
    power = 1. / p
    s = np.linalg.norm(nabla, ord=power)
    if s > 0:
      # Reuse s instead of recomputing the norm.
      br = (nabla / s)**power
    else:
      # An all-zero gradient would give 0 / 0 = NaN. Every strategy is then a
      # best response, so fall back to uniform (matching the p == 0 branch).
      br = np.full_like(nabla, 1. / max(nabla.size, 1), dtype=float)
  else:
    s = np.linalg.norm(nabla, ord=np.inf)
    # Force a floating dtype: zeros_like(dist) would inherit an integer dtype
    # from an integer dist and silently truncate the 1 / k weights below.
    dist_dtype = np.asarray(dist).dtype
    br_dtype = dist_dtype if np.issubdtype(dist_dtype, np.inexact) else float
    br = np.zeros_like(dist, dtype=br_dtype)
    # Spread the best response uniformly over all maximizing actions.
    maxima = (nabla == s)
    br[maxima] = 1. / maxima.sum()

  u_br = nabla.dot(br) + s / (p + 1) * (1 - np.sum(br**(p + 1)))
  u_dist = nabla.dot(dist) + s / (p + 1) * (1 - np.sum(dist**(p + 1)))

  return u_br - u_dist


def qre_exploitability(dist, payoff_tensor, temperature=0.):
  """Compute Shannon regularized exploitability of dist for symmetric game.

  The regularized payoff of a strategy x is
    nabla.dot(x) + temperature * sum(entr(x))
  where nabla is the payoff gradient at dist. For positive temperature the
  best response is softmax(nabla / temperature); otherwise it is uniform over
  the maximizing actions.

  Args:
    dist: 1-d np.array, current estimate of nash distribution
    payoff_tensor: (>=1 x A x ... x A) np.array, payoffs for each joint action
      documented as non-negative, though this function does not enforce it
    temperature: non-negative float; any value that is not strictly positive
      takes the hard arg-max branch
  Returns:
    exploitability (float): payoff of best response - payoff of dist
  """
  num_players = payoff_tensor.shape[0]
  # NOTE(review): presumably pt_reduce contracts every axis except axis 0 with
  # dist, giving one expected payoff per action -- confirm against misc.
  nabla = misc.pt_reduce(payoff_tensor[0], [dist] * num_players, [0])
  if temperature > 0:
    br = special.softmax(nabla / temperature)
  else:
    # Force a floating dtype: zeros_like(dist) would inherit an integer dtype
    # from an integer dist and silently truncate the 1 / k weights below.
    dist_dtype = np.asarray(dist).dtype
    br_dtype = dist_dtype if np.issubdtype(dist_dtype, np.inexact) else float
    br = np.zeros_like(dist, dtype=br_dtype)
    # Spread the best response uniformly over all maximizing actions.
    maxima = (nabla == np.max(nabla))
    br[maxima] = 1. / maxima.sum()

  u_br = nabla.dot(br) + temperature * special.entr(br).sum()
  u_dist = nabla.dot(dist) + temperature * special.entr(dist).sum()

  return u_br - u_dist
